import torch
import torch.nn as nn
from .Conditioner import Conditioner
import networkx as nx

# This implementation of DAGConditioner is adapted from: https://github.com/AWehenkel/Graphical-Normalizing-Flows/blob/master/models/Conditionners/DAGConditioner.py

# Module-wide default device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
class DAGMLP(nn.Module):
    """Fully connected network with a ReLU between consecutive linear layers.

    Layer widths are ``in_size + cond_in -> hidden[0] -> ... -> hidden[-1] ->
    out_size``. No activation follows the last linear layer.
    """

    def __init__(self, in_size, hidden, out_size, cond_in=0):
        """Build the layer stack.

        Args:
            in_size: Number of input features, not counting ``cond_in``.
            hidden: Sequence of hidden-layer widths. May be empty, in which
                case the network is a single linear layer.
            out_size: Number of output features.
            cond_in: Extra input features added to ``in_size`` for the first
                layer.
        """
        super(DAGMLP, self).__init__()
        # Copy into a list so tuples (or any iterable of widths) are accepted.
        hidden = list(hidden)
        widths_in = [in_size + cond_in] + hidden
        widths_out = hidden + [out_size]
        layers = []
        for h1, h2 in zip(widths_in, widths_out):
            layers += [nn.Linear(h1, h2), nn.ReLU()]
        # The loop appends a ReLU after every linear layer; drop the last one
        # so the network output is not rectified.
        layers.pop()
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` and return its output."""
        return self.net(x)


class DAGConditioner(Conditioner):
    def __init__(self, in_size, hidden, out_size, cond_in=0, soft_thresholding=False, h_thresh=0., gumble_T=1.,
                 hot_encoding=True, l1=3, nb_epoch_update=1, A_prior=None):
        """Create the conditioner and its learnable adjacency matrix ``A``.

        Args:
            in_size: Number of variables; ``A`` has shape ``in_size x in_size``.
            hidden: Either an ``nn.Module`` used directly as the embedding
                network, or the hidden-layer widths handed to ``DAGMLP``.
            out_size: Output width handed to ``DAGMLP`` (unused when
                ``hidden`` is already a module).
            cond_in: Extra input width handed to ``DAGMLP``.
            soft_thresholding: Stored as ``s_thresh``.
            h_thresh: Stored as ``h_thresh`` and applied once to ``A``
                through ``constrainA``.
            gumble_T: Temperature stored for the Gumbel gate.
            hot_encoding: When true the embedding network is sized for
                ``2 * in_size`` inputs instead of ``in_size``.
            l1: Value stored in the ``l1_weight`` buffer.
            nb_epoch_update: Stored as ``nb_epoch_update``.
            A_prior: Optional tensor used as the initial value of ``A``;
                its entries equal to zero are recorded in ``A_fix_idx``.
        """
        super(DAGConditioner, self).__init__()
        self.A_prior = A_prior
        self.in_size = in_size
        self.d = in_size
        if A_prior is None:
            init_A = torch.ones(in_size, in_size) * 1.5 + torch.randn((in_size, in_size)) * .02
        else:
            # Copy so that later in-place edits of A never write into the
            # caller's tensor.
            init_A = A_prior.detach().clone()
        # Move to the device BEFORE wrapping: ``nn.Parameter(t).to(device)``
        # returns a plain non-leaf tensor whenever a copy is made (e.g. on
        # CUDA), which would leave A unregistered and untrained.
        self.A = nn.Parameter(init_A.to(device))
        # Boolean mask of entries fixed by the prior (plain ``False`` when no
        # prior is given, since ``None == 0`` is ``False``).
        self.A_fix_idx = A_prior == 0
        self.exponent = self.in_size % 50
        self.s_thresh = soft_thresholding
        self.h_thresh = h_thresh
        self.stoch_gate = True
        self.noise_gate = False
        in_net = in_size * 2 if hot_encoding else in_size
        if isinstance(hidden, nn.Module):
            self.embedding_net = hidden
        else:
            self.embedding_net = DAGMLP(in_net, hidden, out_size, cond_in)
        self.gumble = True
        self.hutchinson = False
        self.gumble_T = gumble_T
        self.hot_encoding = hot_encoding
        with torch.no_grad():
            self.constrainA(h_thresh)
        # Buffers related to the optimization of the constraints on A
        self.register_buffer("lambd", torch.tensor(.0))
        self.register_buffer("c", torch.tensor(1e-3))
        self.register_buffer("eta", torch.tensor(10.))
        self.register_buffer("gamma", torch.tensor(.9))
        self.register_buffer("l1_weight", torch.tensor(l1))
        self.register_buffer("dag_const", torch.tensor(1.))
        self.alpha_factor = 1.
        self.tol = 1e-30
        self.register_buffer("alpha", self.getAlpha())
        # Store the initial trace as a plain value: a buffer that carries an
        # autograd graph keeps that graph alive and breaks ``copy.deepcopy``.
        with torch.no_grad():
            self.register_buffer("prev_trace", self.get_power_trace())
        self.nb_epoch_update = nb_epoch_update
        self.no_update = 0
        self.is_invertible = False
  
    def set_zero_grad(self):
        """Zero the gradient of every entry of ``A`` flagged in ``A_fix_idx``.

        Does nothing when no prior was given or when ``A`` has no gradient
        yet (for example before the first backward pass).
        """
        if self.A_prior is None:
            return
        grad = self.A.grad
        if grad is None:
            return
        with torch.no_grad():
            # One masked write instead of a Python loop over all d*d entries.
            grad.masked_fill_(self.A_fix_idx.to(grad.device), 0.)
                        
    def mask(self, x):
        """Weight each input row by every row of ``A``.

        ``x`` is repeated ``in_size`` times along a new second axis,
        multiplied element-wise with ``A`` (repeated over the first axis of
        ``x``), and the first two axes are then merged, giving
        ``x.shape[0] * in_size`` rows.
        """
        batch = x.shape[0]
        tiled_x = x.unsqueeze(1).expand(-1, self.in_size, -1)
        tiled_A = self.A.unsqueeze(0).expand(batch, -1, -1)
        weighted = tiled_x * tiled_A
        return weighted.view(batch * self.in_size, -1)
    
    def getAlpha(self):
        """Return the step size ``1 / in_size`` as a 0-dim tensor.

        An earlier variant derived the value from the largest singular value
        of ``A ** 2``, but that result was always overwritten by
        ``1 / in_size`` before being returned; the unused decomposition has
        been removed, so the returned value is unchanged.
        """
        return torch.tensor(1. / self.in_size)

    def get_dag(self):
        """Return this conditioner itself, i.e. the object holding the adjacency matrix ``A``."""
        return self

    def post_process(self, zero_threshold=None):
        """Freeze ``A`` into a binary adjacency matrix and disable all gating.

        Entries whose soft-thresholded magnitude exceeds ``zero_threshold``
        become 1, all others 0; the diagonal is then zeroed. Afterwards ``A``
        no longer requires gradients and its stored gradient is dropped.

        Args:
            zero_threshold: Cut-off applied to the soft-thresholded matrix.
                When ``None`` the search starts at 0.1 and increases in steps
                of 0.05 until the thresholded graph is acyclic.
        """
        with torch.no_grad():
            # The soft-thresholded matrix does not change during the search,
            # so it is computed once.
            strength = self.soft_thresholded_A().detach().abs()
            if zero_threshold is None:
                zero_threshold = .1
                while True:
                    adjacency = (strength > zero_threshold).float().cpu().numpy()
                    # ``from_numpy_matrix`` was removed in networkx 3.0;
                    # ``from_numpy_array`` is its replacement.
                    graph = nx.from_numpy_array(adjacency, create_using=nx.DiGraph)
                    if nx.is_directed_acyclic_graph(graph):
                        break
                    zero_threshold += .05
                # NOTE(review): diagonal entries (self-loops) take part in the
                # acyclicity test above although they are zeroed below, which
                # may push the threshold higher than needed - confirm intent.
            self.stoch_gate = False
            self.noise_gate = False
            self.s_thresh = False
            self.h_thresh = 0.
            binary_A = (strength > zero_threshold).float()
            binary_A *= 1. - torch.eye(self.in_size, device=binary_A.device)
            # Assign through ``.data`` so no in-place autograd operation is
            # recorded on a parameter that still requires gradients.
            self.A.data = binary_A
        if self.A.is_leaf:
            self.A.requires_grad = False
        else:
            # ``requires_grad`` can only be changed on leaf tensors.
            self.A.detach_()
        self.A.grad = None

    def stochastic_gate(self, importance):
        """Sample a random gate in ``[0, 1]`` for every entry of ``importance``.

        With ``self.gumble`` set, a two-class Gumbel-softmax sample with
        temperature ``self.gumble_T`` is drawn, where ``importance`` and
        ``1 - importance`` play the role of the two class probabilities.
        Otherwise a Gaussian sample centred at ``importance + 0.25`` is drawn
        and clipped to ``[0, 1]``; its standard deviation shrinks as
        ``importance`` moves away from 0.5.

        Args:
            importance: Tensor of gate probabilities.

        Returns:
            Tensor with the same shape as ``importance``, created on the
            device of ``self.A``.
        """
        if self.gumble:
            temp = self.gumble_T
            epsilon = 1e-6
            u1 = torch.rand(importance.shape, device=self.A.device)
            u2 = torch.rand(importance.shape, device=self.A.device)
            # ``rand`` can return exactly 0, which would give -inf noise.
            tiny = torch.finfo(u1.dtype).tiny
            g1 = -torch.log(-torch.log(u1.clamp_min(tiny)))
            g2 = -torch.log(-torch.log(u2.clamp_min(tiny)))
            logit_on = torch.log(importance + epsilon) + g1
            logit_off = torch.log(1 - importance + epsilon) + g2
            # exp(a/T) / (exp(a/T) + exp(b/T)) == sigmoid((a - b) / T); the
            # sigmoid form cannot overflow to inf/inf = NaN for small T.
            return torch.sigmoid((logit_on - logit_off) / temp)

        beta_1, beta_2 = 3., 10.
        # |x| instead of sqrt(x ** 2): same value, but the gradient stays
        # finite (instead of NaN) when importance is exactly 0.5.
        sigma = beta_1 / (1. + beta_2 * (importance - .5).abs())
        mu = importance
        z = torch.randn(importance.shape, device=self.A.device) * sigma + mu + .25
        return torch.relu(z.clamp_max(1.))

    def noiser_gate(self, x, importance):
        """Scale ``x`` by ``importance`` after adding importance-dependent noise.

        Gaussian noise with standard deviation ``|1 - importance|`` is added
        to ``x`` and the sum is multiplied by ``importance``.

        Args:
            x: Tensor to gate.
            importance: Gate weights; the noise is drawn with this shape on
                the device of ``self.A``.
        """
        # |x| instead of sqrt(x ** 2): same value, but the gradient stays
        # finite (instead of NaN) when importance is exactly 1.
        noise_scale = (1 - importance).abs()
        noise = torch.randn(importance.shape, device=self.A.device) * noise_scale
        return importance * (x + noise)

    def soft_thresholded_A(self):
        """Map ``A`` element-wise to ``2 * (sigmoid(2 * A**2) - 0.5)``.

        The result is 0 where ``A`` is 0 and approaches 1 as ``|A|`` grows.
        """
        squared = self.A ** 2
        squashed = torch.sigmoid(2 * squared)
        return 2 * (squashed - .5)

    def hard_thresholded_A(self):
        """Return edge weights with every value not above ``h_thresh`` set to 0.

        The weights are the soft-thresholded matrix when ``s_thresh`` is set
        and the element-wise square of ``A`` otherwise.
        """
        if not self.s_thresh:
            squared = self.A ** 2
            return squared * (squared > self.h_thresh).float()
        soft = self.soft_thresholded_A()
        return soft * (soft > self.h_thresh).float()
    
    def DAGflow_A(self):
        """Zero the diagonal and lower triangle of ``A`` in place and return ``A``.

        Only entries ``A[i, j]`` with ``i < j`` are kept. The update runs
        without gradient tracking, so it also works while ``A`` is a
        parameter that requires gradients.
        """
        with torch.no_grad():
            self.A.triu_(diagonal=1)
        return self.A

    def forward(self, x, context=None):
        """Compute one embedding per (sample, variable) pair.

        Every sample is repeated ``in_size`` times and multiplied
        element-wise by one row of the edge-weight matrix, so each copy only
        sees the inputs its row allows. The weights are:

        * ``hard_thresholded_A()`` when ``h_thresh > 0``,
        * ``soft_thresholded_A()`` when ``s_thresh`` is set,
        * the raw ``A`` otherwise (no gate is applied in this case).

        For the two thresholded variants ``stoch_gate`` replaces the weights
        by a random gate and ``noise_gate`` (checked only if ``stoch_gate``
        is off) applies the noisy gate instead.

        When ``hot_encoding`` is set, a one-hot code of the row index is
        appended to each masked input before it enters the embedding network.

        Args:
            x: Input batch; assumed to be ``(batch, in_size)`` - TODO confirm
                against callers.
            context: Currently ignored.

        Returns:
            Tensor of shape ``(batch, in_size, -1)`` holding the embedding
            network output for every row.
        """
        batch = x.shape[0]
        if self.h_thresh > 0:
            weights, gated = self.hard_thresholded_A(), True
        elif self.s_thresh:
            weights, gated = self.soft_thresholded_A(), True
        else:
            weights, gated = self.A, False

        tiled_x = x.unsqueeze(1).expand(-1, self.in_size, -1)
        tiled_w = weights.unsqueeze(0).expand(batch, -1, -1)
        if gated and self.stoch_gate:
            e = tiled_x * self.stochastic_gate(tiled_w)
        elif gated and self.noise_gate:
            e = self.noiser_gate(tiled_x, tiled_w)
        else:
            e = tiled_x * tiled_w
        e = e.view(batch * self.in_size, -1)

        if self.hot_encoding:
            # One identity block per sample: row k of each block marks
            # variable k. Built on the device of ``e`` so no transfer to a
            # fixed global device is needed.
            one_hot = torch.eye(self.in_size, device=e.device).repeat(batch, 1)
            e = torch.cat((e, one_hot), 1)

        # Both paths return (batch, in_size, -1). The non-hot-encoding path
        # previously ended in a call on the resulting tensor, which raised
        # TypeError.
        return self.embedding_net(e).view(batch, self.in_size, -1)

    def constrainA(self, zero_threshold=0):
        """Zero, in place, every entry of ``A`` whose magnitude is not above ``zero_threshold``.

        The update runs without gradient tracking so it can be applied to a
        parameter that requires gradients. Returns ``None``.
        """
        with torch.no_grad():
            keep = (self.A.abs() > zero_threshold).float()
            self.A.mul_(keep)

    def get_power_trace(self):
        """Return ``trace((I + alpha * A**2) ** k) - in_size``.

        ``alpha`` is ``min(1, self.alpha) * self.alpha_factor`` and ``A**2``
        is the element-wise square. When ``self.hutchinson`` is non-zero the
        trace is estimated from that many random probe vectors with
        ``k = in_size``; otherwise it is computed exactly with
        ``k = self.exponent``.
        """
        # Out-of-place product: ``alpha *= ...`` would silently rescale the
        # registered ``alpha`` buffer whenever ``min`` returns that tensor.
        alpha = min(1., self.alpha) * self.alpha_factor
        B = torch.eye(self.in_size, device=self.A.device) + alpha * self.A ** 2

        if self.hutchinson != 0:
            h_iter = self.hutchinson
            trace = 0.
            for _ in range(h_iter):
                e0 = torch.randn(self.in_size, 1).to(self.A.device)
                e = e0
                # Apply B in_size times to the probe vector.
                for _ in range(self.in_size):
                    e = B @ e
                trace += (e0 * e).sum()
            return trace / h_iter - self.in_size

        M = torch.matrix_power(B, self.exponent)
        return torch.diag(M).sum() - self.in_size

    